@@ -262,31 +262,29 @@ stride_rank(x, i) = stride_rank(x)[i]
262262function stride_rank (:: Type{R} ) where {T,N,S,A<: Array{S} ,R<: Base.ReinterpretArray{T,N,S,A} }
263263 return nstatic (Val (N))
264264end
265- if VERSION ≥ v " 1.6.0-DEV.1581"
266- @inline function stride_rank (:: Type{A} ) where {NB, NA, B <: AbstractArray{<:Any,NB} ,A<: Base.ReinterpretArray{<:Any, NA, <:Any, B, true} }
265+ @inline function stride_rank (:: Type{A} ) where {NB,NA,B<: AbstractArray{<:Any,NB} ,A<: Base.ReinterpretArray{<:Any,NA,<:Any,B,true} }
267266 NA == NB ? stride_rank (B) : _stride_rank_reinterpret (stride_rank (B), gt (StaticInt {NB} (), StaticInt {NA} ()))
268- end
269- @inline _stride_rank_reinterpret (sr, :: False ) = (One (), map (Base. Fix2 (+ ,One ()),sr)... )
270- @inline _stride_rank_reinterpret (sr:: Tuple{One,Vararg} , :: True ) = map (Base. Fix2 (- ,One ()), tail (sr))
271- # if the leading dim's `stride_rank` is not one, then that means the individual elements are split across an axis, which ArrayInterface
272- # doesn't currently have a means of representing.
273- @inline function contiguous_axis (:: Type{A} ) where {NB, NA, B <: AbstractArray{<:Any,NB} ,A<: Base.ReinterpretArray{<:Any, NA, <:Any, B, true} }
267+ end
268+ @inline _stride_rank_reinterpret (sr, :: False ) = (One (), map (Base. Fix2 (+ , One ()), sr)... )
269+ @inline _stride_rank_reinterpret (sr:: Tuple{One,Vararg} , :: True ) = map (Base. Fix2 (- , One ()), tail (sr))
270+ # if the leading dim's `stride_rank` is not one, then that means the individual elements are split across an axis, which ArrayInterface
271+ # doesn't currently have a means of representing.
272+ @inline function contiguous_axis (:: Type{A} ) where {NB,NA,B <: AbstractArray{<:Any,NB} ,A<: Base.ReinterpretArray{<:Any,NA,<:Any,B, true} }
274273 _reinterpret_contiguous_axis (stride_rank (B), dense_dims (B), contiguous_axis (B), gt (StaticInt {NB} (), StaticInt {NA} ()))
275- end
276- @inline _reinterpret_contiguous_axis (:: Any , :: Any , :: Any , :: False ) = One ()
277- @inline _reinterpret_contiguous_axis (:: Any , :: Any , :: Any , :: True ) = Zero ()
278- @generated function _reinterpret_contiguous_axis (t:: Tuple{One,Vararg{StaticInt,N}} , d:: Tuple{True,Vararg{StaticBool,N}} , :: One , :: True ) where {N}
274+ end
275+ @inline _reinterpret_contiguous_axis (:: Any , :: Any , :: Any , :: False ) = One ()
276+ @inline _reinterpret_contiguous_axis (:: Any , :: Any , :: Any , :: True ) = Zero ()
277+ @generated function _reinterpret_contiguous_axis (t:: Tuple{One,Vararg{StaticInt,N}} , d:: Tuple{True,Vararg{StaticBool,N}} , :: One , :: True ) where {N}
279278 for n in 1 : N
280- if t. parameters[n+ 1 ]. parameters[1 ] === 2
281- if d. parameters[n+ 1 ] === True
282- return :(StaticInt {$n} ())
283- else
284- return :(Zero ())
279+ if t. parameters[n+ 1 ]. parameters[1 ] === 2
280+ if d. parameters[n+ 1 ] === True
281+ return :(StaticInt {$n} ())
282+ else
283+ return :(Zero ())
284+ end
285285 end
286- end
287286 end
288287 :(Zero ())
289- end
290288end
291289
292290function stride_rank (:: Type {Base. ReshapedArray{T, N, P, Tuple{Vararg{Base. SignedMultiplicativeInverse{Int},M}}}}) where {T,N,P,M}
411409function dense_dims (:: Type{S} ) where {N,NP,T,A<: AbstractArray{T,NP} ,I,S<: SubArray{T,N,A,I} }
412410 return _dense_dims (S, dense_dims (A), Val (stride_rank (A)))
413411end
414- if VERSION ≥ v " 1.6.0-DEV.1581"
415- @inline function dense_dims (:: Type{A} ) where {NB, NA, B <: AbstractArray{<:Any,NB} ,A<: Base.ReinterpretArray{<:Any, NA, <:Any, B, true} }
416- ddb = dense_dims (B)
417- IfElse. ifelse (Static. le (StaticInt (NB), StaticInt (NA)), (True (), ddb... ), Base. tail (ddb))
418- end
412+ @inline function dense_dims (:: Type{A} ) where {NB, NA, B <: AbstractArray{<:Any,NB} ,A<: Base.ReinterpretArray{<:Any, NA, <:Any, B, true} }
413+ ddb = dense_dims (B)
414+ IfElse. ifelse (Static. le (StaticInt (NB), StaticInt (NA)), (True (), ddb... ), Base. tail (ddb))
419415end
420416
421417_dense_dims (:: Type{S} , :: Nothing , :: Val{R} ) where {R,N,NP,T,A<: AbstractArray{T,NP} ,I,S<: SubArray{T,N,A,I} } = nothing
@@ -561,70 +557,127 @@ strides(A::StrideIndex) = getfield(A, :strides)
561557 end
562558end
563559
564- # Fixes the example of https://github.com/JuliaArrays/ArrayInterface.jl/issues/160
565- # TODO : Should be generalized to reshaped arrays wrapping more general array types
566- function strides (A:: ReshapedArray{T,N,P} ) where {T, N, P<: AbstractVector }
567- if defines_strides (A)
568- return size_to_strides (size (A), first (strides (parent (A))))
560+ _is_column_dense (:: A ) where {A<: AbstractArray } =
561+ defines_strides (A) &&
562+ (ndims (A) == 0 || Bool (is_dense (A)) && Bool (is_column_major (A)))
563+
564+ # Fixes the example of https://github.com/JuliaArrays/ArrayInterfaceCore.jl/issues/160
565+ function strides (A:: ReshapedArray )
566+ _is_column_dense (parent (A)) && return size_to_strides (size (A), One ())
567+ pst = strides (parent (A))
568+ psz = size (parent (A))
569+ # Try dimension merging in order (starting from dim1).
570+ # `sz1` and `st1` are the `size`/`stride` of dim1 after dimension merging.
571+ # `n` indicates the last merged dimension.
572+ # note: `st1` should be static if possible
573+ sz1, st1, n = merge_adjacent_dim (psz, pst)
574+ n == ndims (A. parent) && return size_to_strides (size (A), st1)
575+ return _reshaped_strides (size (A), One (), sz1, st1, n, Dims (psz), Dims (pst))
576+ end
577+
578+ @inline function _reshaped_strides (:: Dims{0} , reshaped, msz:: Int , _, :: Int , :: Dims , :: Dims )
579+ reshaped == msz && return ()
580+ throw (ArgumentError (" Input is not strided." ))
581+ end
582+ function _reshaped_strides (asz:: Dims , reshaped, msz:: Int , mst, n:: Int , apsz:: Dims , apst:: Dims )
583+ st = reshaped * mst
584+ reshaped = reshaped * asz[1 ]
585+ if length (asz) > 1 && reshaped == msz && asz[2 ] != 1
586+ msz, mst′, n = merge_adjacent_dim (apsz, apst, n + 1 )
587+ reshaped = 1
588+ else
589+ mst′ = Int (mst)
590+ end
591+ sts = _reshaped_strides (tail (asz), reshaped, msz, mst′, n, apsz, apst)
592+ return (st, sts... )
593+ end
594+
595+ merge_adjacent_dim (:: Tuple{} , :: Tuple{} ) = 1 , One (), 0
596+ merge_adjacent_dim (szs:: Tuple{Any} , sts:: Tuple{Any} ) = Int (szs[1 ]), sts[1 ], 1
597+ function merge_adjacent_dim (szs:: Tuple , sts:: Tuple )
598+ if szs[1 ] isa One # Just ignore dimension with size 1
599+ sz, st, n = merge_adjacent_dim (tail (szs), tail (sts))
600+ return sz, st, n + 1
601+ elseif szs[2 ] isa One # Just ignore dimension with size 1
602+ sz, st, n = merge_adjacent_dim ((szs[1 ], tail (tail (szs))... ), (sts[1 ], tail (tail (sts))... ))
603+ return sz, st, n + 1
604+ elseif (szs[1 ], szs[2 ], sts[1 ], sts[2 ]) isa NTuple{4 ,StaticInt} # the check could be done during compiling.
605+ if sts[2 ] == sts[1 ] * szs[1 ]
606+ szs′ = (szs[1 ] * szs[2 ], tail (tail (szs))... )
607+ sts′ = (sts[1 ], tail (tail (sts))... )
608+ sz, st, n = merge_adjacent_dim (szs′, sts′)
609+ return sz, st, n + 1
610+ else
611+ return Int (szs[1 ]), sts[1 ], 1
612+ end
613+ else # the check can't be done during compiling.
614+ sz, st, n = merge_adjacent_dim (Dims (szs), Dims (sts), 1 )
615+ if (szs[1 ], sts[1 ]) isa NTuple{2 ,StaticInt} && szs[1 ] != 1
616+ # But the 1st stride might still be static.
617+ return sz, sts[1 ], n
618+ else
619+ return sz, st, n
620+ end
621+ end
622+ end
623+
624+ function merge_adjacent_dim (psz:: Dims{N} , pst:: Dims{N} , n:: Int ) where {N}
625+ sz, st = psz[n], pst[n]
626+ while n < N
627+ szₙ, stₙ = psz[n+ 1 ], pst[n+ 1 ]
628+ if sz == 1
629+ sz, st = szₙ, stₙ
630+ elseif stₙ == st * sz
631+ sz *= szₙ
632+ elseif szₙ != 1
633+ break
634+ end
635+ n += 1
636+ end
637+ return sz, st, n
638+ end
639+
640+ # `strides` for `Base.ReinterpretArray`
641+ function strides (A:: Base.ReinterpretArray{T,<:Any,S,<:AbstractArray{S},IsReshaped} ) where {T,S,IsReshaped}
642+ _is_column_dense (parent (A)) && return size_to_strides (size (A), One ())
643+ stp = strides (parent (A))
644+ ET, ES = static (sizeof (T)), static (sizeof (S))
645+ ET === ES && return stp
646+ IsReshaped && ET < ES && return (One (), _reinterp_strides (stp, ET, ES)... )
647+ first (stp) == 1 || throw (ArgumentError (" Parent must be contiguous in the 1st dimension!" ))
648+ if IsReshaped
649+ # The wrapper tell us `A`'s parent has static size in dim1.
650+ # We can make the next stride static if the following dim is still dense.
651+ sr = stride_rank (parent (A))
652+ dd = dense_dims (parent (A))
653+ stp′ = _new_static (stp, sr, dd, ET ÷ ES)
654+ return _reinterp_strides (tail (stp′), ET, ES)
569655 else
570- return Base . strides (A )
656+ return ( One (), _reinterp_strides ( tail (stp), ET, ES) ... )
571657 end
572658end
573- function strides (A:: ReshapedArray{T,N,P} ) where {T, N, P}
574- if defines_strides (A)
575- return size_to_strides (size (A), static (1 ))
659+ _new_static (P,_,_,_) = P # This should never be called, just in case.
660+ @generated function _new_static (p:: P , :: SR , :: DD , :: StaticInt{S} ) where {S,N,P<: NTuple{N,Union{Int,StaticInt}} ,SR<: NTuple{N,StaticInt} ,DD<: NTuple{N,StaticBool} }
661+ sr = fieldtypes (SR)
662+ j = findfirst (T -> T () == sr[1 ]()+ 1 , sr)
663+ if ! isnothing (j) && ! (fieldtype (P, j) <: StaticInt ) && fieldtype (DD, j) === True
664+ return :(tuple ($ ((i == j ? :(static ($ S)) : :(p[$ i]) for i in 1 : N). .. )))
576665 else
577- return Base. strides (A)
578- end
579- end
580-
581-
582- @inline bmap (f:: F , t:: Tuple{} , x:: Number ) where {F} = ()
583- @inline bmap (f:: F , t:: Tuple{T} , x:: Number ) where {F, T} = (f (first (t),x), )
584- @inline bmap (f:: F , t:: Tuple , x:: Number ) where {F} = (f (first (t),x), bmap (f, Base. tail (t), x)... )
585- @static if VERSION ≥ v " 1.6.0-DEV.1581"
586- # from `reinterpret(reshape, ...)`
587- @inline function strides (A:: Base.ReinterpretArray{R, N, T, B, true} ) where {R,N,T,B}
588- P = strides (parent (A))
589- if sizeof (R) == sizeof (T)
590- P
591- elseif sizeof (R) > sizeof (T)
592- x = Base. tail (P)
593- fx = first (x)
594- if fx isa Int
595- (One (), bmap (Base. sdiv_int, Base. tail (x), fx)... )
596- else
597- (One (), bmap (÷ , Base. tail (x), fx)... )
598- end
666+ return :(p)
667+ end
668+ end
669+ @inline function _reinterp_strides (stp:: Tuple , els:: StaticInt , elp:: StaticInt )
670+ if elp % els == 0
671+ N = elp ÷ els
672+ return map (i -> N * i, stp)
599673 else
600- (One (), bmap (* , P, StaticInt (sizeof (T)) ÷ StaticInt (sizeof (R)))... )
601- end
602- end
603- # plain `reinterpret(...)`
604- @inline function strides (A:: Base.ReinterpretArray{R, N, T, B, false} ) where {R,N,T,B}
605- P = strides (parent (A))
606- if sizeof (R) == sizeof (T)
607- P
608- elseif sizeof (R) > sizeof (T)
609- (first (P), bmap (÷ , Base. tail (P), StaticInt (sizeof (R)) ÷ StaticInt (sizeof (T)))... )
610- else # sizeof(R) < sizeof(T)
611- (first (P), bmap (* , Base. tail (P), StaticInt (sizeof (T)) ÷ StaticInt (sizeof (R)))... )
612- end
613- end
614- else
615- # plain `reinterpret(...)`
616- @inline function strides (A:: Base.ReinterpretArray{R, N, T} ) where {R,N,T}
617- P = strides (parent (A))
618- if sizeof (R) == sizeof (T)
619- P
620- elseif sizeof (R) > sizeof (T)
621- (first (P), bmap (÷ , Base. tail (P), StaticInt (sizeof (R)) ÷ StaticInt (sizeof (T)))... )
622- else # sizeof(R) < sizeof(T)
623- (first (P), bmap (* , Base. tail (P), StaticInt (sizeof (T)) ÷ StaticInt (sizeof (R)))... )
624- end
625- end
626- end
627- # @inline strides(A) = _strides(A, Base.strides(A), contiguous_axis(A))
674+ return map (stp) do i
675+ d, r = divrem (elp * i, els)
676+ iszero (r) || throw (ArgumentError (" Parent's strides could not be exactly divided!" ))
677+ d
678+ end
679+ end
680+ end
628681
629682strides (:: AbstractRange ) = (One (),)
630683function strides (x:: VecAdjTrans )
0 commit comments