@@ -9,7 +9,7 @@ using SuiteSparse
99 using Base: @assume_effects
1010else
1111 macro assume_effects (_, ex)
12- Base. @pure ex
12+ :( Base. @pure $ (ex))
1313 end
1414end
1515
@@ -22,6 +22,72 @@ const MatAdjTrans{T,M<:AbstractMatrix{T}} = Union{Transpose{T,M},Adjoint{T,M}}
2222const UpTri{T,M} = Union{UpperTriangular{T,M},UnitUpperTriangular{T,M}}
2323const LoTri{T,M} = Union{LowerTriangular{T,M},UnitLowerTriangular{T,M}}
2424
25+ """
26+ ArrayInterfaceCore.map_tuple_type(f, T::Type{<:Tuple})
27+
28+ Returns tuple where each field corresponds to the field type of `T` modified by the function `f`.
29+
30+ # Examples
31+
32+ ```julia
33+ julia> ArrayInterfaceCore.map_tuple_type(sqrt, Tuple{1,4,16})
34+ (1.0, 2.0, 4.0)
35+
36+ ```
37+ """
38+ function map_tuple_type (f:: F , :: Type{T} ) where {F,T<: Tuple }
39+ if @generated
40+ t = Expr (:tuple )
41+ for i in 1 : fieldcount (T)
42+ push! (t. args, :(f ($ (fieldtype (T, i)))))
43+ end
44+ Expr (:block , Expr (:meta , :inline ), t)
45+ else
46+ Tuple (f (fieldtype (T, i)) for i in 1 : fieldcount (T))
47+ end
48+ end
49+
50+ """
51+ ArrayInterfaceCore.flatten_tuples(t::Tuple) -> Tuple
52+
53+ Flattens any field of `t` that is a tuple. Only direct fields of `t` may be flattened.
54+
55+ # Examples
56+
57+ ```julia
58+ julia> ArrayInterfaceCore.flatten_tuples((1, ()))
59+ (1,)
60+
61+ julia> ArrayInterfaceCore.flatten_tuples((1, (2, 3)))
62+ (1, 2, 3)
63+
64+ julia> ArrayInterfaceCore.flatten_tuples((1, (2, (3,))))
65+ (1, 2, (3,))
66+
67+ ```
68+ """
69+ @inline function flatten_tuples (t:: Tuple )
70+ if @generated
71+ texpr = Expr (:tuple )
72+ for i in 1 : fieldcount (t)
73+ p = fieldtype (t, i)
74+ if p <: Tuple
75+ for j in 1 : fieldcount (p)
76+ push! (texpr. args, :(@inbounds (getfield (getfield (t, $ i), $ j))))
77+ end
78+ else
79+ push! (texpr. args, :(@inbounds (getfield (t, $ i))))
80+ end
81+ end
82+ Expr (:block , Expr (:meta , :inline ), texpr)
83+ else
84+ _flatten (t)
85+ end
86+ end
87+ _flatten (:: Tuple{} ) = ()
88+ @inline _flatten (t:: Tuple{Any,Vararg{Any}} ) = (getfield (t, 1 ), _flatten (Base. tail (t))... )
89+ @inline _flatten (t:: Tuple{Tuple,Vararg{Any}} ) = (getfield (t, 1 )... , _flatten (Base. tail (t))... )
90+
2591"""
2692 parent_type(::Type{T}) -> Type
2793
@@ -591,32 +657,100 @@ indexing with an instance of `I`.
591657"""
592658ndims_shape (T:: DataType ) = ndims_index (T)
593659ndims_shape (:: Type{Colon} ) = 1
594- ndims_shape (T:: Type{<:Base.AbstractCartesianIndex{N}} ) where {N} = ntuple (zero, Val {N} () )
595- ndims_shape (@nospecialize T:: Type{<:CartesianIndices} ) = ntuple (one, Val {ndims(T)} ())
596- ndims_shape (@nospecialize T:: Type{<:Number} ) = 0
660+ ndims_shape (@nospecialize T:: Type{<:CartesianIndices} ) = ndims (T )
661+ ndims_shape (@nospecialize T:: Type{<:Union{Number,Base.AbstractCartesianIndex}} ) = 0
662+ ndims_shape (@nospecialize T:: Type{<:AbstractArray{Bool}} ) = 1
597663ndims_shape (@nospecialize T:: Type{<:AbstractArray} ) = ndims (T)
598664ndims_shape (x) = ndims_shape (typeof (x))
599665
666+ @assume_effects :total function _find_first_true (isi:: Tuple{Vararg{Bool,N}} ) where {N}
667+ for i in 1 : N
668+ getfield (isi, i) && return i
669+ end
670+ return nothing
671+ end
672+
600673"""
601- IndicesInfo(T::Type{<:Tuple}) -> IndicesInfo{NI,NS,IS }()
674+ IndicesInfo{N} (T::Type{<:Tuple}) -> IndicesInfo{N, NI,NS}()
602675
603676Provides basic trait information for each index type in in the tuple `T`. `NI`, `NS`, and
604677`IS` are tuples of [`ndims_index`](@ref), [`ndims_shape`](@ref), and
605678[`is_splat_index`](@ref) (respectively) for each field of `T`.
679+
680+ # Examples
681+
682+ ```julia
683+ julia> using ArrayInterfaceCore: IndicesInfo
684+
685+ julia> IndicesInfo{5}(typeof((:,[CartesianIndex(1,1),CartesianIndex(1,1)], 1, ones(Int, 2, 2), :, 1)))
686+ IndicesInfo{5, (1, (2, 3), 4, 5, 0, 0), (1, 2, 0, (3, 4), 5, 0)}()
687+
688+ ```
606689"""
607- struct IndicesInfo{NI,NS,IS} end
608- IndicesInfo (@nospecialize x:: Tuple ) = IndicesInfo (typeof (x))
609- @generated function IndicesInfo (:: Type{T} ) where {T<: Tuple }
610- NI = Expr (:tuple )
611- NS = Expr (:tuple )
612- IS = Expr (:tuple )
613- for i in 1 : fieldcount (T)
614- T_i = fieldtype (T, i)
615- push! (NI. args, :(ndims_index ($ (T_i))))
616- push! (NS. args, :(ndims_shape ($ (T_i))))
617- push! (IS. args, :(is_splat_index ($ (T_i))))
690+ struct IndicesInfo{N,NI,NS} end
691+ IndicesInfo (x:: SubArray ) = IndicesInfo {ndims(parent(x))} (typeof (x. indices))
692+ @inline function IndicesInfo (@nospecialize T:: Type{<:SubArray} )
693+ IndicesInfo {ndims(parent_type(T))} (fieldtype (T, :indices ))
694+ end
695+ function IndicesInfo {N} (@nospecialize (T:: Type{<:Tuple} )) where {N}
696+ _indices_info (
697+ Val {_find_first_true(map_tuple_type(is_splat_index, T))} (),
698+ IndicesInfo {N,map_tuple_type(ndims_index, T),map_tuple_type(ndims_shape, T)} ()
699+ )
700+ end
701+ function _indices_info (:: Val{nothing} , :: IndicesInfo{1,(1,),NS} ) where {NS}
702+ ns1 = getfield (NS, 1 )
703+ IndicesInfo {1,(1,), (ns1 > 1 ? ntuple(identity, ns1) : ns1,)} ()
704+ end
705+ function _indices_info (:: Val{nothing} , :: IndicesInfo{N,(1,),NS} ) where {N,NS}
706+ ns1 = getfield (NS, 1 )
707+ IndicesInfo {N,(:,),(ns1 > 1 ? ntuple(identity, ns1) : ns1,)} ()
708+ end
709+ @inline function _indices_info (:: Val{nothing} , :: IndicesInfo{N,NI,NS} ) where {N,NI,NS}
710+ if sum (NI) > N
711+ IndicesInfo {N,_replace_trailing(N, _accum_dims(cumsum(NI), NI)), _accum_dims(cumsum(NS), NS)} ()
712+ else
713+ IndicesInfo {N,_accum_dims(cumsum(NI), NI), _accum_dims(cumsum(NS), NS)} ()
714+ end
715+ end
716+ @inline function _indices_info (:: Val{SI} , :: IndicesInfo{N,NI,NS} ) where {N,NI,NS,SI}
717+ nsplat = N - sum (NI)
718+ if nsplat === 0
719+ _indices_info (Val {nothing} (), IndicesInfo {N,NI,NS} ())
720+ else
721+ splatmul = max (0 , nsplat + 1 )
722+ _indices_info (Val {nothing} (), IndicesInfo {N,_map_splats(splatmul, SI, NI),_map_splats(splatmul, SI, NS)} ())
723+ end
724+ end
725+ @inline function _map_splats (nsplat:: Int , splat_index:: Int , dims:: Tuple{Vararg{Int}} )
726+ ntuple (length (dims)) do i
727+ i === splat_index ? (nsplat * getfield (dims, i)) : getfield (dims, i)
728+ end
729+ end
730+ @inline function _replace_trailing (n:: Int , dims:: Tuple{Vararg{Any,N}} ) where {N}
731+ ntuple (N) do i
732+ dim_i = getfield (dims, i)
733+ if dim_i isa Tuple
734+ ntuple (length (dim_i)) do j
735+ dim_i_j = getfield (dim_i, j)
736+ dim_i_j > n ? 0 : dim_i_j
737+ end
738+ else
739+ dim_i > n ? 0 : dim_i
740+ end
741+ end
742+ end
743+ @inline function _accum_dims (csdims:: NTuple{N,Int} , nd:: NTuple{N,Int} ) where {N}
744+ ntuple (N) do i
745+ nd_i = getfield (nd, i)
746+ if nd_i === 0
747+ 0
748+ elseif nd_i === 1
749+ getfield (csdims, i)
750+ else
751+ ntuple (Base. Fix1 (+ , getfield (csdims, i) - nd_i), nd_i)
752+ end
618753 end
619- Expr (:block , Expr (:meta , :inline ), :(IndicesInfo {$(NI),$(NS),$(IS)} ()))
620754end
621755
622756"""
0 commit comments