Skip to content

Commit 6f7940d

Browse files
committed
fix some of the issues found by JET.jl
1 parent 8e5bc0b commit 6f7940d

File tree

3 files changed

+11
-6
lines changed

3 files changed

+11
-6
lines changed

src/named_array_partition.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""
22
NamedArrayPartition(; kwargs...)
3-
NamedArrayPartition(x::NamedTuple)
3+
NamedArrayPartition(x::NamedTuple)
44
55
Similar to an `ArrayPartition` but the individual arrays can be accessed via the
66
constructor-specified names. However, unlike `ArrayPartition`, each individual array
@@ -22,7 +22,7 @@ function NamedArrayPartition(x::NamedTuple)
2222
return NamedArrayPartition(ArrayPartition{T, S}(values(x)), names_to_indices)
2323
end
2424

25-
# Note: overloading `getproperty` means we cannot access `NamedArrayPartition`
25+
# Note: overloading `getproperty` means we cannot access `NamedArrayPartition`
2626
# fields except through `getfield` and accessor functions.
2727
ArrayPartition(x::NamedArrayPartition) = getfield(x, :array_partition)
2828

@@ -53,7 +53,7 @@ end
5353
function Base.similar(
5454
A::NamedArrayPartition, ::Type{T}, ::Type{S}, R::DataType...) where {T, S}
5555
NamedArrayPartition(
56-
similar(getfield(A, :array_partition), T, S, R), getfield(A, :names_to_indices))
56+
similar(getfield(A, :array_partition), T, S, R...), getfield(A, :names_to_indices))
5757
end
5858

5959
Base.Array(x::NamedArrayPartition) = Array(ArrayPartition(x))
@@ -68,7 +68,7 @@ function Base.getproperty(x::NamedArrayPartition, s::Symbol)
6868
getindex(ArrayPartition(x).x, getproperty(getfield(x, :names_to_indices), s))
6969
end
7070

71-
# this enables x.s = some_array.
71+
# this enables x.s = some_array.
7272
@inline function Base.setproperty!(x::NamedArrayPartition, s::Symbol, v)
7373
index = getproperty(getfield(x, :names_to_indices), s)
7474
ArrayPartition(x).x[index] .= v

src/utils.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,8 @@ function recursivefill!(b::AbstractArray{T, N},
119119
a::T2) where {T <: StaticArraysCore.SArray,
120120
T2 <: Union{Number, Bool}, N}
121121
@inbounds for i in eachindex(b)
122-
b[i] = fill(a, typeof(b[i]))
122+
# Preserve static array shape while replacing all entries with the scalar
123+
b[i] = map(_ -> a, b[i])
123124
end
124125
end
125126

@@ -128,7 +129,8 @@ function recursivefill!(bs::AbstractVectorOfArray{T, N},
128129
T2 <: Union{Number, Bool}, N}
129130
@inbounds for b in bs, i in eachindex(b)
130131

131-
b[i] = fill(a, typeof(b[i]))
132+
# Preserve static array shape while replacing all entries with the scalar
133+
b[i] = map(_ -> a, b[i])
132134
end
133135
end
134136

src/vector_of_array.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -669,6 +669,8 @@ function Base.view(A::AbstractVectorOfArray{T, N, <:AbstractVector{T}},
669669
J = map(i -> Base.unalias(A, i), to_indices(A, I))
670670
elseif length(I) == 2 && (I[1] == Colon() || I[1] == 1)
671671
J = map(i -> Base.unalias(A, i), to_indices(A, Base.tail(I)))
672+
else
673+
J = map(i -> Base.unalias(A, i), to_indices(A, I))
672674
end
673675
@boundscheck checkbounds(A, J...)
674676
SubArray(A, J)
@@ -939,6 +941,7 @@ end
939941

940942
struct VectorOfArrayStyle{N} <: Broadcast.AbstractArrayStyle{N} end # N is only used when voa sees other abstract arrays
941943
VectorOfArrayStyle{N}(::Val{N}) where {N} = VectorOfArrayStyle{N}()
944+
VectorOfArrayStyle(::Val{N}) where {N} = VectorOfArrayStyle{N}()
942945

943946
# The order is important here. We want to override Base.Broadcast.DefaultArrayStyle to return another Base.Broadcast.DefaultArrayStyle.
944947
Broadcast.BroadcastStyle(a::VectorOfArrayStyle, ::Base.Broadcast.DefaultArrayStyle{0}) = a

0 commit comments

Comments
 (0)