Skip to content

Commit 4cbfa4b

Browse files
committed
Further changes.
1 parent bb8df83 commit 4cbfa4b

File tree

3 files changed

+74
-15
lines changed

3 files changed

+74
-15
lines changed

src/ArrayInterface.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ parent_type(::Type{<:LinearAlgebra.AbstractTriangular{T,S}}) where {T,S} = S
2525
parent_type(::Type{<:PermutedDimsArray{T,N,I1,I2,A}}) where {T,N,I1,I2,A} = A
2626
parent_type(::Type{Slice{T}}) where {T} = T
2727
parent_type(::Type{T}) where {T} = T
28+
parent_type(::Type{R}) where {S, T, A <: AbstractArray{S}, N, R <: Base.ReinterpretArray{T, N, S, A}} = A
2829

2930
"""
3031
known_length(::Type{T})
@@ -880,10 +881,11 @@ function __init__()
880881
size(A::OffsetArrays.OffsetArray) = size(parent(A))
881882
strides(A::OffsetArrays.OffsetArray) = strides(parent(A))
882883
# offsets(A::OffsetArrays.OffsetArray) = map(+, A.offsets, offsets(parent(A)))
883-
device(::OffsetArrays.OffsetArray) = CheckParent()
884-
contiguous_axis(A::OffsetArrays.OffsetArray) = contiguous_axis(parent(A))
885-
contiguous_batch_size(A::OffsetArrays.OffsetArray) = contiguous_batch_size(parent(A))
886-
stride_rank(A::OffsetArrays.OffsetArray) = stride_rank(parent(A))
884+
parent_type(::Type{O}) where {T,N,A<:AbstractArray{T,N},O<:OffsetArrays.OffsetArray{T,N,A}} = A
885+
device(::Type{<:OffsetArrays.OffsetArray}) = CheckParent()
886+
contiguous_axis(::Type{A}) where {A <: OffsetArrays.OffsetArray} = contiguous_axis(parent_type(A))
887+
contiguous_batch_size(::Type{A}) where {A <: OffsetArrays.OffsetArray} = contiguous_batch_size(parent_type(A))
888+
stride_rank(::Type{A}) where {A <: OffsetArrays.OffsetArray} = stride_rank(parent_type(A))
887889
end
888890
end
889891

src/static.jl

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,21 +22,23 @@ Base.Integer(x::StaticInt{N}) where {N} = x
2222
(::Type{T})(x::StaticInt{N}) where {T<:Integer,N} = T(N)
2323
(::Type{T})(x::Int) where {T<:StaticInt} = StaticInt(x)
2424
Base.convert(::Type{StaticInt{N}}, ::StaticInt{N}) where {N} = StaticInt{N}()
25+
Base.float(::StaticInt{N}) where {N} = Float64(N)
2526

26-
Base.promote_rule(::Type{<:StaticInt}, ::Type{T}) where {T <: AbstractIrrational} = promote_rule(Int, T)
27-
Base.promote_rule(::Type{T}, ::Type{<:StaticInt}) where {T <: AbstractIrrational} = promote_rule(T, Int)
27+
Base.promote_rule(::Type{<:StaticInt}, ::Type{T}) where {T <: Number} = promote_type(Int, T)
28+
Base.promote_rule(::Type{<:StaticInt}, ::Type{T}) where {T <: AbstractIrrational} = promote_type(Int, T)
29+
# Base.promote_rule(::Type{T}, ::Type{<:StaticInt}) where {T <: AbstractIrrational} = promote_rule(T, Int)
2830
for (S,T) [(:Complex,:Real), (:Rational, :Integer), (:(Base.TwicePrecision),:Any)]
29-
@eval Base.promote_rule(::Type{$S{T}}, ::Type{<:StaticInt}) where {T <: $T} = promote_rule($S{T}, Int)
31+
@eval Base.promote_rule(::Type{$S{T}}, ::Type{<:StaticInt}) where {T <: $T} = promote_type($S{T}, Int)
3032
end
3133
Base.promote_rule(::Type{Union{Nothing,Missing}}, ::Type{<:StaticInt}) = Union{Nothing, Missing, Int}
32-
Base.promote_rule(::Type{T}, ::Type{<:StaticInt}) where {T >: Union{Missing,Nothing}} = promote_rule(T, Int)
33-
Base.promote_rule(::Type{T}, ::Type{<:StaticInt}) where {T >: Nothing} = promote_rule(T, Int)
34-
Base.promote_rule(::Type{T}, ::Type{<:StaticInt}) where {T >: Missing} = promote_rule(T, Int)
34+
Base.promote_rule(::Type{T}, ::Type{<:StaticInt}) where {T >: Union{Missing,Nothing}} = promote_type(T, Int)
35+
Base.promote_rule(::Type{T}, ::Type{<:StaticInt}) where {T >: Nothing} = promote_type(T, Int)
36+
Base.promote_rule(::Type{T}, ::Type{<:StaticInt}) where {T >: Missing} = promote_type(T, Int)
3537
for T [:Bool, :Missing, :BigFloat, :BigInt, :Nothing, :Any]
3638
# let S = :Any
3739
@eval begin
38-
Base.promote_rule(::Type{S}, ::Type{$T}) where {S <: StaticInt} = promote_rule(Int, $T)
39-
Base.promote_rule(::Type{$T}, ::Type{S}) where {S <: StaticInt} = promote_rule($T, Int)
40+
Base.promote_rule(::Type{S}, ::Type{$T}) where {S <: StaticInt} = promote_type(Int, $T)
41+
Base.promote_rule(::Type{$T}, ::Type{S}) where {S <: StaticInt} = promote_type($T, Int)
4042
end
4143
end
4244
Base.promote_rule(::Type{<:StaticInt}, ::Type{<:StaticInt}) = Int

src/stridelayout.jl

Lines changed: 58 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,14 @@ _contiguous_axis(::Any, ::Nothing) = nothing
5959
Expr(:call, Expr(:curly, :Contiguous, new_contig))
6060
end
6161

62+
# contiguous_if_one(::Contiguous{1}) = Contiguous{1}()
63+
# contiguous_if_one(::Any) = Contiguous{-1}()
64+
function contiguous_axis(::Type{R}) where {T, N, S, A <: Array{S}, R <: Base.ReinterpretArray{T, N, S, A}}
65+
isbitstype(S) ? Contiguous{1}() : nothing
66+
# contiguous_if_one(contiguous_axis(parent_type(R)))
67+
end
68+
69+
6270
"""
6371
contiguous_axis_indicator(::Type{T}) -> Tuple{Vararg{<:Val}}
6472
@@ -102,6 +110,8 @@ _contiguous_batch_size(::Any, ::Any, ::Any) = nothing
102110
end
103111
end
104112

113+
contiguous_batch_size(::Type{R}) where {T, N, S, A <: Array{S}, R <: Base.ReinterpretArray{T, N, S, A}} = ContiguousBatch{0}()
114+
105115
struct StrideRank{R} end
106116
Base.@pure StrideRank(R::NTuple{<:Any,Int}) = StrideRank{R}()
107117
_get(::StrideRank{R}) where {R} = R
@@ -158,6 +168,7 @@ _stride_rank(::Any, ::Any) = nothing
158168
Expr(:call, Expr(:curly, :StrideRank, ranktup))
159169
end
160170
stride_rank(x, i) = stride_rank(x)[i]
171+
stride_rank(::Type{R}) where {T, N, S, A <: Array{S}, R <: Base.ReinterpretArray{T, N, S, A}} = StrideRank{ntuple(identity, Val{N}())}()
161172

162173
"""
163174
is_column_major(A) -> Val{true/false}()
@@ -247,6 +258,9 @@ julia> ArrayInterface.size(A)
247258
```
248259
"""
249260
size(A) = Base.size(A)
261+
size(x::LinearAlgebra.Adjoint{T,V}) where {T, V <: AbstractVector{T}} = (One(), static_length(x))
262+
size(x::LinearAlgebra.Transpose{T,V}) where {T, V <: AbstractVector{T}} = (One(), static_length(x))
263+
250264
"""
251265
strides(A)
252266
@@ -257,6 +271,16 @@ julia> A = rand(3,4);
257271
258272
julia> ArrayInterface.strides(A)
259273
(StaticInt{1}(), 3)
274+
275+
Additionally, the behavior differs from `Base.strides` for adjoint vectors:
276+
277+
julia> x = rand(5);
278+
279+
julia> ArrayInterface.strides(x')
280+
(StaticInt{1}(), StaticInt{1}())
281+
282+
This is to support the pattern of using just the first stride for linear indexing, `x[i]`,
283+
while still producing correct behavior when using valid cartesian indices, such as `x[1,i]`.
260284
```
261285
"""
262286
strides(A) = Base.strides(A)
@@ -272,14 +296,24 @@ offsets(::Any) = (StaticInt{1}(),) # Assume arbitrary Julia data structures use
272296
@inline strides(A::Vector{<:Any}) = (StaticInt(1),)
273297
@inline strides(A::Array{<:Any,N}) where {N} = (StaticInt(1), Base.tail(Base.strides(A))...)
274298
@inline strides(A::AbstractArray) = _strides(A, Base.strides(A), contiguous_axis(A))
299+
300+
@inline function strides(x::LinearAlgebra.Adjoint{T,V}) where {T, V <: AbstractVector{T}}
301+
strd = stride(parent(x), One())
302+
(strd, strd)
303+
end
304+
@inline function strides(x::LinearAlgebra.Transpose{T,V}) where {T, V <: AbstractVector{T}}
305+
strd = stride(parent(x), One())
306+
(strd, strd)
307+
end
308+
275309
@generated function _strides(A::AbstractArray{T,N}, s::NTuple{N}, ::Contiguous{C}) where {T,N,C}
276310
if C 0 || C > N
277311
return Expr(:block, Expr(:meta,:inline), :s)
278312
end
279313
stup = Expr(:tuple)
280314
for n 1:N
281315
if n == C
282-
push!(stup.args, :(StaticInt{$(sizeof(T))}()))
316+
push!(stup.args, :(One()))
283317
else
284318
push!(stup.args, Expr(:ref, :s, n))
285319
end
@@ -290,6 +324,22 @@ offsets(::Any) = (StaticInt{1}(),) # Assume arbitrary Julia data structures use
290324
end
291325
end
292326

327+
if VERSION v"1.6.0-DEV.1581"
328+
@generated function _strides(_::Base.ReinterpretArray{T, N, S, A, true}, s::NTuple{N}, ::Contiguous{1}) where {T, N, S, D, A <: Array{S,D}}
329+
stup = Expr(:tuple, :(One()))
330+
if D < N
331+
push!(stup.args, Expr(:call, Expr(:curly, :StaticInt, sizeof(S) ÷ sizeof(T))))
332+
end
333+
for n 2+(D < N):N
334+
push!(stup.args, Expr(:ref, :s, n))
335+
end
336+
quote
337+
$(Expr(:meta,:inline))
338+
@inbounds $stup
339+
end
340+
end
341+
end
342+
293343
@inline function offsets(x, i)
294344
inds = indices(x, i)
295345
start = known_first(inds)
@@ -313,7 +363,7 @@ end
313363
@inline strides(B::PermutedDimsArray{T,N,I1,I2,A}) where {T,N,I1,I2,A<:AbstractArray{T,N}} = permute(strides(parent(B)), Val{I1}())
314364
@inline stride(A::AbstractArray, ::StaticInt{N}) where {N} = strides(A)[N]
315365
@inline stride(A::AbstractArray, ::Val{N}) where {N} = strides(A)[N]
316-
stride(A, i) = Base.stride(A, i)
366+
stride(A, i) = Base.stride(A, i) # for type stability
317367

318368
size(B::S) where {N,NP,T,A<:AbstractArray{T,NP},I,S <: SubArray{T,N,A,I}} = _size(size(parent(B)), B.indices, map(static_length, B.indices))
319369
strides(B::S) where {N,NP,T,A<:AbstractArray{T,NP},I,S <: SubArray{T,N,A,I}} = _strides(strides(parent(B)), B.indices)
@@ -333,11 +383,16 @@ end
333383
@generated function _strides(A::Tuple{Vararg{Any,N}}, inds::I) where {N, I<:Tuple}
334384
t = Expr(:tuple)
335385
for n in 1:N
336-
if I.parameters[n] <: AbstractRange
386+
if I.parameters[n] <: AbstractUnitRange
337387
push!(t.args, Expr(:ref, :A, n))
388+
elseif I.parameters[n] <: AbstractRange
389+
push!(t.args, Expr(:call, :(*), Expr(:ref, :A, n), Expr(:call, :static_step, Expr(:ref, :inds, n))))
338390
elseif !(I.parameters[n] <: Integer)
339391
return nothing
340392
end
341393
end
342394
Expr(:block, Expr(:meta, :inline), t)
343395
end
396+
397+
398+

0 commit comments

Comments
 (0)