@@ -59,6 +59,14 @@ _contiguous_axis(::Any, ::Nothing) = nothing
5959 Expr (:call , Expr (:curly , :Contiguous , new_contig))
6060end
6161
62+ # contiguous_if_one(::Contiguous{1}) = Contiguous{1}()
63+ # contiguous_if_one(::Any) = Contiguous{-1}()
64+ function contiguous_axis (:: Type{R} ) where {T, N, S, A <: Array{S} , R <: Base.ReinterpretArray{T, N, S, A} }
65+ isbitstype (S) ? Contiguous {1} () : nothing
66+ # contiguous_if_one(contiguous_axis(parent_type(R)))
67+ end
68+
69+
6270"""
6371contiguous_axis_indicator(::Type{T}) -> Tuple{Vararg{<:Val}}
6472
@@ -102,6 +110,8 @@ _contiguous_batch_size(::Any, ::Any, ::Any) = nothing
102110 end
103111end
104112
113+ contiguous_batch_size (:: Type{R} ) where {T, N, S, A <: Array{S} , R <: Base.ReinterpretArray{T, N, S, A} } = ContiguousBatch {0} ()
114+
105115struct StrideRank{R} end
106116Base. @pure StrideRank (R:: NTuple{<:Any,Int} ) = StrideRank {R} ()
107117_get (:: StrideRank{R} ) where {R} = R
@@ -158,6 +168,7 @@ _stride_rank(::Any, ::Any) = nothing
158168 Expr (:call , Expr (:curly , :StrideRank , ranktup))
159169end
160170stride_rank (x, i) = stride_rank (x)[i]
171+ stride_rank (:: Type{R} ) where {T, N, S, A <: Array{S} , R <: Base.ReinterpretArray{T, N, S, A} } = StrideRank {ntuple(identity, Val{N}())} ()
161172
162173"""
163174is_column_major(A) -> Val{true/false}()
@@ -247,6 +258,9 @@ julia> ArrayInterface.size(A)
247258```
248259"""
249260size (A) = Base. size (A)
261+ size (x:: LinearAlgebra.Adjoint{T,V} ) where {T, V <: AbstractVector{T} } = (One (), static_length (x))
262+ size (x:: LinearAlgebra.Transpose{T,V} ) where {T, V <: AbstractVector{T} } = (One (), static_length (x))
263+
250264"""
251265 strides(A)
252266
@@ -257,6 +271,16 @@ julia> A = rand(3,4);
257271
258272julia> ArrayInterface.strides(A)
259273(StaticInt{1}(), 3)
274+
275+ Additionally, the behavior differs from `Base.strides` for adjoint vectors:
276+
277+ julia> x = rand(5);
278+
279+ julia> ArrayInterface.strides(x')
280+ (StaticInt{1}(), StaticInt{1}())
281+
282+ This is to support the pattern of using just the first stride for linear indexing, `x[i]`,
283+ while still producing correct behavior when using valid cartesian indices, such as `x[1,i]`.
260284```
261285"""
262286strides (A) = Base. strides (A)
@@ -272,14 +296,24 @@ offsets(::Any) = (StaticInt{1}(),) # Assume arbitrary Julia data structures use
272296@inline strides (A:: Vector{<:Any} ) = (StaticInt (1 ),)
273297@inline strides (A:: Array{<:Any,N} ) where {N} = (StaticInt (1 ), Base. tail (Base. strides (A))... )
274298@inline strides (A:: AbstractArray ) = _strides (A, Base. strides (A), contiguous_axis (A))
299+
300+ @inline function strides (x:: LinearAlgebra.Adjoint{T,V} ) where {T, V <: AbstractVector{T} }
301+ strd = stride (parent (x), One ())
302+ (strd, strd)
303+ end
304+ @inline function strides (x:: LinearAlgebra.Transpose{T,V} ) where {T, V <: AbstractVector{T} }
305+ strd = stride (parent (x), One ())
306+ (strd, strd)
307+ end
308+
275309@generated function _strides (A:: AbstractArray{T,N} , s:: NTuple{N} , :: Contiguous{C} ) where {T,N,C}
276310 if C ≤ 0 || C > N
277311 return Expr (:block , Expr (:meta ,:inline ), :s )
278312 end
279313 stup = Expr (:tuple )
280314 for n ∈ 1 : N
281315 if n == C
282- push! (stup. args, :(StaticInt {$(sizeof(T))} ()))
316+ push! (stup. args, :(One ()))
283317 else
284318 push! (stup. args, Expr (:ref , :s , n))
285319 end
@@ -290,6 +324,22 @@ offsets(::Any) = (StaticInt{1}(),) # Assume arbitrary Julia data structures use
290324 end
291325end
292326
327+ if VERSION ≥ v " 1.6.0-DEV.1581"
328+ @generated function _strides (_:: Base.ReinterpretArray{T, N, S, A, true} , s:: NTuple{N} , :: Contiguous{1} ) where {T, N, S, D, A <: Array{S,D} }
329+ stup = Expr (:tuple , :(One ()))
330+ if D < N
331+ push! (stup. args, Expr (:call , Expr (:curly , :StaticInt , sizeof (S) ÷ sizeof (T))))
332+ end
333+ for n ∈ 2 + (D < N): N
334+ push! (stup. args, Expr (:ref , :s , n))
335+ end
336+ quote
337+ $ (Expr (:meta ,:inline ))
338+ @inbounds $ stup
339+ end
340+ end
341+ end
342+
293343@inline function offsets (x, i)
294344 inds = indices (x, i)
295345 start = known_first (inds)
313363@inline strides (B:: PermutedDimsArray{T,N,I1,I2,A} ) where {T,N,I1,I2,A<: AbstractArray{T,N} } = permute (strides (parent (B)), Val {I1} ())
314364@inline stride (A:: AbstractArray , :: StaticInt{N} ) where {N} = strides (A)[N]
315365@inline stride (A:: AbstractArray , :: Val{N} ) where {N} = strides (A)[N]
316- stride (A, i) = Base. stride (A, i)
366+ stride (A, i) = Base. stride (A, i) # for type stability
317367
318368size (B:: S ) where {N,NP,T,A<: AbstractArray{T,NP} ,I,S <: SubArray{T,N,A,I} } = _size (size (parent (B)), B. indices, map (static_length, B. indices))
319369strides (B:: S ) where {N,NP,T,A<: AbstractArray{T,NP} ,I,S <: SubArray{T,N,A,I} } = _strides (strides (parent (B)), B. indices)
@@ -333,11 +383,16 @@ end
333383@generated function _strides (A:: Tuple{Vararg{Any,N}} , inds:: I ) where {N, I<: Tuple }
334384 t = Expr (:tuple )
335385 for n in 1 : N
336- if I. parameters[n] <: AbstractRange
386+ if I. parameters[n] <: AbstractUnitRange
337387 push! (t. args, Expr (:ref , :A , n))
388+ elseif I. parameters[n] <: AbstractRange
389+ push! (t. args, Expr (:call , :(* ), Expr (:ref , :A , n), Expr (:call , :static_step , Expr (:ref , :inds , n))))
338390 elseif ! (I. parameters[n] <: Integer )
339391 return nothing
340392 end
341393 end
342394 Expr (:block , Expr (:meta , :inline ), t)
343395end
396+
397+
398+
0 commit comments