Skip to content

Commit b7bafdb

Browse files
committed
Added bitshift methods for Static, improved inferabiity of transposed dimnames.
1 parent df479d1 commit b7bafdb

File tree

3 files changed

+21
-21
lines changed

3 files changed

+21
-21
lines changed

src/dimensions.jl

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,20 @@ Return the names of the dimensions for `x`.
3030
end
3131
end
3232
@inline function dimnames(::Type{T}) where {T<:Union{Transpose,Adjoint}}
33-
return _transpose_dimnames(dimnames(parent_type(T)))
33+
return _transpose_dimnames(Val(dimnames(parent_type(T))))
3434
end
35-
_transpose_dimnames(x::Tuple{Symbol,Symbol}) = (last(x), first(x))
36-
_transpose_dimnames(x::Tuple{Symbol}) = (:_, first(x))
35+
# inserting the Val here seems to help inferability; I got a test failure without it.
36+
function _transpose_dimnames(::Val{S}) where {S}
37+
if length(S) == 1
38+
(:_, first(S))
39+
elseif length(S) == 2
40+
(last(S), first(S))
41+
else
42+
throw("Can't transpose $S of dim $(length(S)).")
43+
end
44+
end
45+
@inline _transpose_dimnames(x::Tuple{Symbol,Symbol}) = (last(x), first(x))
46+
@inline _transpose_dimnames(x::Tuple{Symbol}) = (:_, first(x))
3747

3848
@inline function dimnames(::Type{T}) where {I,T<:PermutedDimsArray{<:Any,<:Any,I}}
3949
return map(i -> dimnames(parent_type(T), i), I)
@@ -143,6 +153,8 @@ julia> ArrayInterface.size(A)
143153
"""
144154
size(A) = Base.size(A)
145155
size(A, d) = Base.size(A, to_dims(A, d))
156+
size(x::LinearAlgebra.Adjoint{T,V}) where {T, V <: AbstractVector{T}} = (One(), static_length(x))
157+
size(x::LinearAlgebra.Transpose{T,V}) where {T, V <: AbstractVector{T}} = (One(), static_length(x))
146158

147159
"""
148160
axes(A, d)

src/static.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,12 @@ end
8787
for f [:(+), :(-), :(*), :(/), :(÷), :(%), :(<<), :(>>), :(>>>), :(&), :(|), :()]
8888
@eval @generated Base.$f(::StaticInt{M}, ::StaticInt{N}) where {M,N} = Expr(:call, Expr(:curly, :StaticInt, $f(M, N)))
8989
end
90+
for f [:(<<), :(>>), :(>>>)]
91+
@eval begin
92+
@inline Base.$f(::StaticInt{M}, x::UInt) where {M} = $f(M, x)
93+
@inline Base.$f(x::Integer, ::StaticInt{M}) where {M} = $f(x, M)
94+
end
95+
end
9096
for f [:(==), :(!=), :(<), :(), :(>), :()]
9197
@eval begin
9298
@inline Base.$f(::StaticInt{M}, ::StaticInt{N}) where {M,N} = $f(M, N)

src/stridelayout.jl

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -249,24 +249,6 @@ permute(t::NTuple{N}, I::NTuple{N,Int}) where {N} = ntuple(n -> t[I[n]], Val{N}(
249249
Expr(:block, Expr(:meta, :inline), t)
250250
end
251251

252-
"""
253-
size(A)
254-
255-
Returns the size of `A`. If the size of any axes are known at compile time,
256-
these should be returned as `Static` numbers. For example:
257-
```julia
258-
julia> using StaticArrays, ArrayInterface
259-
260-
julia> A = @SMatrix rand(3,4);
261-
262-
julia> ArrayInterface.size(A)
263-
(StaticInt{3}(), StaticInt{4}())
264-
```
265-
"""
266-
size(A) = Base.size(A)
267-
size(x::LinearAlgebra.Adjoint{T,V}) where {T, V <: AbstractVector{T}} = (One(), static_length(x))
268-
size(x::LinearAlgebra.Transpose{T,V}) where {T, V <: AbstractVector{T}} = (One(), static_length(x))
269-
270252
"""
271253
strides(A)
272254

0 commit comments

Comments
 (0)