@@ -806,74 +806,122 @@ ndims_shape(x) = ndims_shape(typeof(x))
806806end
807807
808808"""
809+ IndicesInfo{N}(inds::Tuple) -> IndicesInfo{N}(typeof(inds))
809810 IndicesInfo{N}(T::Type{<:Tuple}) -> IndicesInfo{N,pdims,cdims}()
811+ IndicesInfo(inds::Tuple) -> IndicesInfo(typeof(inds))
812+ IndicesInfo(T::Type{<:Tuple}) -> IndicesInfo{maximum(pdims),pdims,cdims}()
810813
811- Provides basic trait information for each index type in in the tuple `T`. `pdims` and
812- `cdims` are dimension mappings to the parent and child dimensions respectively.
814+
815+ Maps a tuple of indices to `N` dimensions. The resulting `pdims` is a tuple where each
816+ field in `inds` (or field type in `T`) corresponds to the parent dimensions accessed.
817+ `cdims` similarly maps indices to the resulting child array produced after indexing with
818+ `inds`. If `N` is not provided then it is assumed that all indices are represented by parent
819+ dimensions and there are no trailing dimensions accessed. These may be accessed by through
820+ `parentdims(info::IndicesInfo)` and `childdims(info::IndicesInfo)`. If `N` is not provided,
821+ it is assumed that no indices are accessing trailing dimensions (which are represented as
822+ `0` in `parentdims(info)[index_position]`).
823+
824+ The the fields and types of `IndicesInfo` should not be accessed directly.
825+ Instead [`parentdims`](@ref), [`childdims`](@ref), [`ndims_index`](@ref), and
826+ [`ndims_shape`](@ref) should be used to extract relevant information.
813827
814828# Examples
815829
816830```julia
817- julia> using ArrayInterfaceCore: IndicesInfo
831+ julia> using ArrayInterfaceCore: IndicesInfo, parentdims, childdims, ndims_index, ndims_shape
832+
833+ julia> info = IndicesInfo{5}(typeof((:,[CartesianIndex(1,1),CartesianIndex(1,1)], 1, ones(Int, 2, 2), :, 1)));
834+
835+ julia> parentdims(info) # the last two indices access trailing dimensions
836+ (1, (2, 3), 4, 5, 0, 0)
837+
838+ julia> childdims(info)
839+ (1, 2, 0, (3, 4), 5, 0)
840+
841+ julia> childdims(info)[3] # index 3 accesses a parent dimension but is dropped in the child array
842+ 0
818843
819- julia> IndicesInfo{5}(typeof((:,[CartesianIndex(1,1),CartesianIndex(1,1)], 1, ones(Int, 2, 2), :, 1)))
820- IndicesInfo{5, (1, (2, 3), 4, 5, 0, 0), (1, 2, 0, (3, 4), 5, 0)}()
844+ julia> ndims_index(info)
845+ 5
846+
847+ julia> ndims_shape(info)
848+ 5
849+
850+ julia> info = IndicesInfo(typeof((:,[CartesianIndex(1,1),CartesianIndex(1,1)], 1, ones(Int, 2, 2), :, 1)));
851+
852+ julia> parentdims(info) # assumed no trailing dimensions
853+ (1, (2, 3), 4, 5, 6, 7)
854+
855+ julia> ndims_index(info) # assumed no trailing dimensions
856+ 7
821857
822858```
823859"""
824- struct IndicesInfo{N,NI,NS} end
825- IndicesInfo (x:: SubArray ) = IndicesInfo {ndims(parent(x))} (typeof (x. indices))
826- @inline function IndicesInfo (@nospecialize T:: Type{<:SubArray} )
827- IndicesInfo {ndims(parent_type(T))} (fieldtype (T, :indices ))
828- end
829- function IndicesInfo {N} (@nospecialize (T:: Type{<:Tuple} )) where {N}
830- _indices_info (
831- Val {_find_first_true(map_tuple_type(is_splat_index, T))} (),
832- IndicesInfo {N,map_tuple_type(ndims_index, T),map_tuple_type(ndims_shape, T)} ()
833- )
834- end
835- function _indices_info (:: Val{nothing} , :: IndicesInfo{1,(1,),NS} ) where {NS}
836- ns1 = getfield (NS, 1 )
837- IndicesInfo {1,(1,), (ns1 > 1 ? ntuple(identity, ns1) : ns1,)} ()
838- end
839- function _indices_info (:: Val{nothing} , :: IndicesInfo{N,(1,),NS} ) where {N,NS}
840- ns1 = getfield (NS, 1 )
841- IndicesInfo {N,(:,),(ns1 > 1 ? ntuple(identity, ns1) : ns1,)} ()
842- end
843- @inline function _indices_info (:: Val{nothing} , :: IndicesInfo{N,NI,NS} ) where {N,NI,NS}
844- if sum (NI) > N
845- IndicesInfo {N,_replace_trailing(N, _accum_dims(cumsum(NI), NI)), _accum_dims(cumsum(NS), NS)} ()
846- else
847- IndicesInfo {N,_accum_dims(cumsum(NI), NI), _accum_dims(cumsum(NS), NS)} ()
860+ struct IndicesInfo{Np,pdims,cdims,Nc}
861+ function IndicesInfo {N} (@nospecialize (T:: Type{<:Tuple} )) where {N}
862+ SI = _find_first_true (map_tuple_type (is_splat_index, T))
863+ NI = map_tuple_type (ndims_index, T)
864+ NS = map_tuple_type (ndims_shape, T)
865+ if SI === nothing
866+ ndi = NI
867+ nds = NS
868+ else
869+ nsplat = N - sum (NI)
870+ if nsplat === 0
871+ ndi = NI
872+ nds = NS
873+ else
874+ splatmul = max (0 , nsplat + 1 )
875+ ndi = _map_splats (splatmul, SI, NI)
876+ nds = _map_splats (splatmul, SI, NS)
877+ end
878+ end
879+ if ndi === (1 ,) && N != = 1
880+ ns1 = getfield (nds, 1 )
881+ new {N,(:,),(ns1 > 1 ? ntuple(identity, ns1) : ns1,),ns1} ()
882+ else
883+ nds_cumsum = cumsum (nds)
884+ if sum (ndi) > N
885+ init_pdims = _accum_dims (cumsum (ndi), ndi)
886+ pdims = ntuple (nfields (init_pdims)) do i
887+ dim_i = getfield (init_pdims, i)
888+ if dim_i isa Tuple
889+ ntuple (length (dim_i)) do j
890+ dim_i_j = getfield (dim_i, j)
891+ dim_i_j > N ? 0 : dim_i_j
892+ end
893+ else
894+ dim_i > N ? 0 : dim_i
895+ end
896+ end
897+ new {N, pdims, _accum_dims(nds_cumsum, nds), last(nds_cumsum)} ()
898+ else
899+ new {N,_accum_dims(cumsum(ndi), ndi), _accum_dims(nds_cumsum, nds), last(nds_cumsum)} ()
900+ end
901+ end
848902 end
849- end
850- @inline function _indices_info (:: Val{SI} , :: IndicesInfo{N,NI,NS} ) where {N,NI,NS,SI}
851- nsplat = N - sum (NI)
852- if nsplat === 0
853- _indices_info (Val {nothing} (), IndicesInfo {N,NI,NS} ())
854- else
855- splatmul = max (0 , nsplat + 1 )
856- _indices_info (Val {nothing} (), IndicesInfo {N,_map_splats(splatmul, SI, NI),_map_splats(splatmul, SI, NS)} ())
903+ IndicesInfo {N} (@nospecialize (t:: Tuple )) where {N} = IndicesInfo {N} (typeof (t))
904+ function IndicesInfo (@nospecialize (T:: Type{<:Tuple} ))
905+ ndi = map_tuple_type (ndims_index, T)
906+ nds = map_tuple_type (ndims_shape, T)
907+ ndi_sum = cumsum (ndi)
908+ nds_sum = cumsum (nds)
909+ nf = nfields (ndi_sum)
910+ pdims = _accum_dims (ndi_sum, ndi)
911+ cdims = _accum_dims (nds_sum, nds)
912+ new {getfield(ndi_sum, nf),pdims,cdims,getfield(nds_sum, nf)} ()
857913 end
914+ IndicesInfo (@nospecialize t:: Tuple ) = IndicesInfo (typeof (t))
915+ @inline function IndicesInfo (@nospecialize T:: Type{<:SubArray} )
916+ IndicesInfo {ndims(parent_type(T))} (fieldtype (T, :indices ))
917+ end
918+ IndicesInfo (x:: SubArray ) = IndicesInfo {ndims(parent(x))} (typeof (x. indices))
858919end
859920@inline function _map_splats (nsplat:: Int , splat_index:: Int , dims:: Tuple{Vararg{Int}} )
860921 ntuple (length (dims)) do i
861922 i === splat_index ? (nsplat * getfield (dims, i)) : getfield (dims, i)
862923 end
863924end
864- @inline function _replace_trailing (n:: Int , dims:: Tuple{Vararg{Any,N}} ) where {N}
865- ntuple (N) do i
866- dim_i = getfield (dims, i)
867- if dim_i isa Tuple
868- ntuple (length (dim_i)) do j
869- dim_i_j = getfield (dim_i, j)
870- dim_i_j > n ? 0 : dim_i_j
871- end
872- else
873- dim_i > n ? 0 : dim_i
874- end
875- end
876- end
877925@inline function _accum_dims (csdims:: NTuple{N,Int} , nd:: NTuple{N,Int} ) where {N}
878926 ntuple (N) do i
879927 nd_i = getfield (nd, i)
887935 end
888936end
889937
938+ _lower_info (:: IndicesInfo{Np,pdims,cdims,Nc} ) where {Np,pdims,cdims,Nc} = Np,pdims,cdims,Nc
939+
940+ ndims_index (@nospecialize (info:: IndicesInfo )) = getfield (_lower_info (info), 1 )
941+ ndims_shape (@nospecialize (info:: IndicesInfo )) = getfield (_lower_info (info), 4 )
942+
943+ """
944+ parentdims(::IndicesInfo) -> Tuple
945+
946+ Returns the parent dimension mapping from `IndicesInfo`.
947+
948+ See also: [`IndicesInfo`](@ref), [`childdims`](@ref)
949+ """
950+ parentdims (@nospecialize info:: IndicesInfo ) = getfield (_lower_info (info), 2 )
951+
952+ """
953+ childdims(::IndicesInfo) -> Tuple
954+
955+ Returns the child dimension mapping from `IndicesInfo`.
956+
957+ See also: [`IndicesInfo`](@ref), [`parentdims`](@ref)
958+ """
959+ childdims (@nospecialize info:: IndicesInfo ) = getfield (_lower_info (info), 3 )
960+
961+
890962"""
891963 instances_do_not_alias(::Type{T}) -> Bool
892964
0 commit comments